from mbi import Dataset
from scipy.sparse.linalg import lsmr
import numpy as np
import jax.numpy as jnp
from jax import vjp
from jax.nn import softmax as jax_softmax
import pandas as pd

""" This file is experimental.

It is a close approximation to the method described in RAP (https://arxiv.org/abs/2103.06641)
and an even closer approximation to RAP^{softmax} (https://arxiv.org/abs/2106.07153)

Notable differences:
- Code now shares the same interface as Private-PGM (see FactoredInference)
- Named model "MixtureOfProducts", as that is one interpretation for the relaxed tabular format
(at least when softmax is used).
- Added support for unbounded-DP, with automatic estimate of total.
"""


def estimate_total(measurements):
    """Find the minimum-variance estimate of the total given the measurements.

    For each measurement (Q, y, noise, proj), if there is a vector v with Q^T v = 1,
    then v^T y is an unbiased estimate of the total with variance noise^2 * ||v||^2.
    The individual estimates are combined by inverse-variance weighting.
    """
    variances = np.array([])
    estimates = np.array([])
    for Q, y, noise, proj in measurements:
        o = np.ones(Q.shape[1])
        v = lsmr(Q.T, o, atol=0, btol=0)[0]
        if np.allclose(Q.T.dot(v), o):
            variances = np.append(variances, noise**2 * np.dot(v, v))
            estimates = np.append(estimates, np.dot(v, y))
    if estimates.size == 0:
        return 1
    else:
        # inverse-variance weighted combination of the individual estimates
        variance = 1.0 / np.sum(1.0 / variances)
        estimate = variance * np.sum(estimates / variances)
        return max(1, estimate)
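
# Worked example (illustrative): if Q is the identity matrix over a marginal with n
# cells, then v is the all-ones vector, so the estimate is y.sum() and its variance
# is n * noise**2.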

def adam(loss_and_grad, x0, iters=250):
    """Minimize a function using the Adam optimizer.

    :param loss_and_grad: a function mapping x to a (loss, gradient) pair
    :param x0: the initial guess
    :param iters: the number of iterations to run
    """
    a = 1.0
    b1, b2 = 0.9, 0.999
    eps = 1e-8

    x = x0
    m = np.zeros_like(x)
    v = np.zeros_like(x)
    for t in range(1, iters + 1):
        l, g = loss_and_grad(x)
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g**2
        mhat = m / (1 - b1**t)
        vhat = v / (1 - b2**t)
        x = x - a * mhat / (np.sqrt(vhat) + eps)
    return x
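
# Usage sketch (illustrative, not part of the original module): `loss_and_grad` must
# return a (loss, gradient) pair, e.g.
#
#   f = lambda x: (0.5 * np.sum((x - 3.0) ** 2), x - 3.0)
#   x = adam(f, np.zeros(5), iters=100)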

def synthetic_col(counts, total):
    """Generate a column of `total` values whose empirical distribution approximates
    the given (unnormalized) vector of counts, using randomized rounding."""
    counts = np.array(counts, dtype=float)
    counts *= total / counts.sum()
    frac, integ = np.modf(counts)
    integ = integ.astype(int)
    extra = total - integ.sum()
    if extra > 0:
        idx = np.random.choice(counts.size, extra, False, frac / frac.sum())
        integ[idx] += 1
    vals = np.repeat(np.arange(counts.size), integ)
    np.random.shuffle(vals)
    return vals
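
# Worked example (illustrative): with counts = [1.5, 2.5] and total = 4, the integer
# parts give [1, 2] rows for the two values, and the one remaining row is assigned at
# random with probability proportional to the fractional parts [0.5, 0.5].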

class MixtureOfProducts:
    """A distribution over the domain, represented as a uniform mixture of product
    distributions (one factor per attribute for each mixture component)."""

    def __init__(self, products, domain, total):
        self.products = products
        self.domain = domain
        self.total = total
        self.num_components = next(iter(products.values())).shape[0]

    def project(self, cols):
        products = { col : self.products[col] for col in cols }
        domain = self.domain.project(cols)
        return MixtureOfProducts(products, domain, self.total)
    
    def datavector(self, flatten=True):
        # contract the per-attribute factors into a full contingency table: axis 'a'
        # indexes the mixture component and one letter per attribute indexes its values
        letters = 'bcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'[:len(self.domain)]
        formula = ','.join(['a%s' % l for l in letters]) + '->' + ''.join(letters)
        components = [self.products[col] for col in self.domain]
        ans = np.einsum(formula, *components) * self.total / self.num_components
        return ans.flatten() if flatten else ans

    def synthetic_data(self, rows=None):
        total = rows or int(self.total)
        # generate a roughly equal number of rows from each component (slightly more
        # than needed), then shuffle and truncate to the requested number of rows
        subtotal = total // self.num_components + 1

        dfs = []
        for i in range(self.num_components):
            df = pd.DataFrame()
            for col in self.products:
                counts = self.products[col][i]
                df[col] = synthetic_col(counts, subtotal)
            dfs.append(df)

        df = pd.concat(dfs).sample(frac=1).reset_index(drop=True)[:total]
        return Dataset(df, self.domain)
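
# Illustrative usage sketch (assumes `model` is a MixtureOfProducts returned by
# MixtureInference.estimate below; the attribute names are hypothetical):
#
#   model.project(('A', 'B')).datavector()  # marginal counts over attributes 'A' and 'B'
#   model.synthetic_data(rows=500)          # a Dataset of 500 synthetic records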

class MixtureInference:
    def __init__(self, domain, components=10, metric='L2', iters=2500, warm_start=False):
        """
        :param domain: A Domain object
        :param components: The number of mixture components
        :metric: The metric to use for the loss function (can be callable)
        """
        self.domain = domain
        self.components = components
        self.metric = metric
        self.iters = iters
        self.warm_start = warm_start
        self.params = np.random.normal(loc=0, scale=0.25, size=sum(domain.shape) * components)

    def estimate(self, measurements, total=None, alpha=0.1):
        """Estimate a MixtureOfProducts model from noisy measurements.

        :param measurements: A list of (Q, y, noise, proj) tuples
        :param total: The total number of records (estimated from the measurements if None)
        :param alpha: Unused by this implementation
        """
        if total is None:
            total = estimate_total(measurements)
        self.measurements = measurements
        cliques = [M[-1] for M in measurements]
        letters = 'bcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'

        def get_products(params):
            # unpack the flat parameter vector into one (components x domain-size) factor
            # per attribute, normalized with a softmax over the attribute's values
            products = {}
            idx = 0
            for col in self.domain:
                n = self.domain[col]
                k = self.components
                products[col] = jax_softmax(params[idx:idx + k * n].reshape(k, n), axis=1)
                idx += k * n
            return products
            
        def marginals_from_params(params):
            # compute the model marginal on each measured clique by contracting the
            # relevant factors with einsum, summing over the mixture-component axis 'a'
            products = get_products(params)
            mu = {}
            for cl in cliques:
                let = letters[:len(cl)]
                formula = ','.join(['a%s' % l for l in let]) + '->' + ''.join(let)
                components = [products[col] for col in cl]
                ans = jnp.einsum(formula, *components) * total / self.components
                mu[cl] = ans.flatten()
            return mu

        def loss_and_grad(params):
            # For computing dL / dmu we will use ordinary numpy so as to support scipy sparse and linear operator inputs
            # For computing dL / dparams we will use jax to avoid manually deriving gradients
            params = jnp.array(params)
            mu, backprop = vjp(marginals_from_params, params)
            mu = { cl : np.array(mu[cl]) for cl in cliques }
            loss, dL = self._marginal_loss(mu)
            dL = { cl : jnp.array(dL[cl]) for cl in cliques }
            dparams = backprop(dL)
            return loss, np.array(dparams[0])
          
        if not self.warm_start:
            self.params = np.random.normal(loc=0, scale=0.25, size=sum(self.domain.shape) * self.components)
        self.params = adam(loss_and_grad, self.params, iters=self.iters)
        products = get_products(self.params)
        return MixtureOfProducts(products, self.domain, total)
    
    def _marginal_loss(self, marginals, metric=None):
        """ Compute the loss and gradient for a given dictionary of marginals

        :param marginals: A dictionary with keys as projections and values as Factors
        :return loss: the loss value
        :return grad: A dictionary with gradient for each marginal 
        """
        if metric is None:
            metric = self.metric
        if callable(metric):
            # a callable metric maps the marginals dictionary to a (loss, gradient) pair
            return metric(marginals)

        loss = 0.0
        gradient = { cl : np.zeros_like(marginals[cl]) for cl in marginals }

        for Q, y, noise, cl in self.measurements:
            x = marginals[cl]
            c = 1.0/noise
            diff = c*(Q @ x - y)
            if metric == 'L1':
                loss += abs(diff).sum()
                sign = diff.sign() if hasattr(diff, 'sign') else np.sign(diff)
                grad = c*(Q.T @ sign)
            else:
                loss += 0.5*(diff @ diff)
                grad = c*(Q.T @ diff)
            gradient[cl] += grad

        return float(loss), gradient
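

if __name__ == '__main__':
    # Minimal end-to-end demo; a hedged sketch, assuming the mbi Domain / Dataset
    # helpers behave as in Private-PGM.  The attribute names, domain sizes, noise
    # scale, and iteration count below are arbitrary illustrative choices.
    from mbi import Domain

    domain = Domain(['A', 'B', 'C'], [2, 3, 4])
    data = Dataset.synthetic(domain, 1000)

    # measure two noisy 2-way marginals with identity query matrices
    measurements = []
    for proj in [('A', 'B'), ('B', 'C')]:
        x = data.project(proj).datavector()
        y = x + np.random.normal(loc=0, scale=10.0, size=x.size)
        measurements.append((np.eye(x.size), y, 10.0, proj))

    engine = MixtureInference(domain, components=10, iters=500)
    model = engine.estimate(measurements)  # the total is estimated automatically
    print(model.project(('A', 'B')).datavector())
    print(model.synthetic_data().df.head())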
